agentmux_srv\backend\wshutil/
wshrpc.rs1#![allow(dead_code)]
2use std::collections::HashMap;
17use std::sync::{Arc, Mutex};
18use std::sync::atomic::{AtomicBool, Ordering};
19use serde_json::Value;
20use tokio::sync::mpsc;
21use uuid::Uuid;
22
23use super::event::EventListener;
24use super::proxy::{RpcContext, RpcMessage};
25use super::osc::{DEFAULT_INPUT_CH_SIZE, DEFAULT_OUTPUT_CH_SIZE};
26
/// How long, in milliseconds, `send_rpc_request` waits for a reply when the
/// caller passes `None` as the timeout.
pub const DEFAULT_TIMEOUT_MS: u64 = 5_000;
/// Buffer capacity of the per-request response channel created by `send_rpc_request`.
pub const RESP_CH_SIZE: usize = 32;
31
/// One response message routed to a pending `send_rpc_request` call.
#[derive(Debug, Clone)]
pub struct RpcResponse {
    /// Payload sent by the peer, if any.
    pub data: Option<Value>,
    /// Error text sent by the peer; `send_rpc_request` surfaces it as `Err`.
    pub error: Option<String>,
    /// `true` when the peer's message had `cont == false` (no more responses follow).
    pub is_final: bool,
}
39
/// Bookkeeping for one outstanding request, stored in `WshRpc::rpc_map` under
/// the request id.
struct RpcData {
    /// Where `handle_response` forwards responses whose `res_id` matches.
    resp_tx: mpsc::Sender<RpcResponse>,
}
44
/// Handle used to reply to a single incoming request.
pub struct RpcResponseHandler {
    /// Id of the request being answered; empty means the sender expects no reply.
    pub req_id: String,
    /// Command name carried by the incoming request.
    pub command: String,
    /// Raw request payload.
    pub data: Option<Value>,
    /// RPC context attached to the request.
    pub rpc_context: RpcContext,
    /// Channel on which serialized response messages are written.
    response_tx: mpsc::Sender<Vec<u8>>,
    /// Set once a terminal response (done or error) was sent, or `finalize` was called.
    finalized: AtomicBool,
}
54
55impl RpcResponseHandler {
56 pub fn get_command(&self) -> &str {
58 &self.command
59 }
60
61 pub fn get_command_raw_data(&self) -> Option<&Value> {
63 self.data.as_ref()
64 }
65
66 pub fn get_rpc_context(&self) -> &RpcContext {
68 &self.rpc_context
69 }
70
71 pub fn needs_response(&self) -> bool {
73 !self.req_id.is_empty()
74 }
75
76 pub async fn send_response(&self, data: Option<Value>, done: bool) -> Result<(), String> {
78 if !self.needs_response() {
79 return Ok(());
80 }
81
82 let msg = RpcMessage {
83 res_id: self.req_id.clone(),
84 data,
85 cont: !done,
86 ..Default::default()
87 };
88
89 let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
90 self.response_tx
91 .send(json)
92 .await
93 .map_err(|e| format!("send response: {}", e))?;
94
95 if done {
96 self.finalized.store(true, Ordering::SeqCst);
97 }
98 Ok(())
99 }
100
101 pub async fn send_response_error(&self, err: &str) -> Result<(), String> {
103 if !self.needs_response() {
104 return Ok(());
105 }
106
107 let msg = RpcMessage {
108 res_id: self.req_id.clone(),
109 error: Some(err.to_string()),
110 ..Default::default()
111 };
112
113 let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
114 self.response_tx
115 .send(json)
116 .await
117 .map_err(|e| format!("send error response: {}", e))?;
118
119 self.finalized.store(true, Ordering::SeqCst);
120 Ok(())
121 }
122
123 pub fn finalize(&self) {
125 self.finalized.store(true, Ordering::SeqCst);
126 }
127
128 pub fn is_finalized(&self) -> bool {
130 self.finalized.load(Ordering::SeqCst)
131 }
132}
133
/// Boxed, thread-safe command handler invoked with the response handler for one
/// incoming request.
// NOTE(review): the meaning of the returned `bool` is not visible in this file —
// presumably whether the handler finalizes the response itself; confirm at call sites.
pub type CommandHandlerFn = Box<dyn Fn(RpcResponseHandler) -> bool + Send + Sync>;
136
137pub struct WshRpc {
145 input_ch: mpsc::Sender<Vec<u8>>,
146 output_ch: mpsc::Sender<Vec<u8>>,
147 rpc_context: Arc<Mutex<Option<RpcContext>>>,
148 auth_token: Arc<Mutex<String>>,
149 rpc_map: Arc<Mutex<HashMap<String, RpcData>>>,
150 event_listener: Arc<EventListener>,
151 debug: AtomicBool,
152 debug_name: String,
153 server_done: AtomicBool,
154}
155
156impl WshRpc {
157 pub fn new(debug_name: &str) -> (Self, mpsc::Receiver<Vec<u8>>, mpsc::Sender<Vec<u8>>) {
159 let (input_tx, _input_rx) = mpsc::channel(DEFAULT_INPUT_CH_SIZE);
160 let (output_tx, output_rx) = mpsc::channel(DEFAULT_OUTPUT_CH_SIZE);
161
162 let input_tx_clone = input_tx.clone();
163 let rpc = Self {
164 input_ch: input_tx,
165 output_ch: output_tx,
166 rpc_context: Arc::new(Mutex::new(None)),
167 auth_token: Arc::new(Mutex::new(String::new())),
168 rpc_map: Arc::new(Mutex::new(HashMap::new())),
169 event_listener: Arc::new(EventListener::new()),
170 debug: AtomicBool::new(false),
171 debug_name: debug_name.to_string(),
172 server_done: AtomicBool::new(false),
173 };
174
175 (rpc, output_rx, input_tx_clone)
176 }
177
178 pub fn set_rpc_context(&self, ctx: RpcContext) {
180 *self.rpc_context.lock().unwrap() = Some(ctx);
181 }
182
183 pub fn get_rpc_context(&self) -> Option<RpcContext> {
185 self.rpc_context.lock().unwrap().clone()
186 }
187
188 pub fn set_auth_token(&self, token: &str) {
190 *self.auth_token.lock().unwrap() = token.to_string();
191 }
192
193 pub fn get_auth_token(&self) -> String {
195 self.auth_token.lock().unwrap().clone()
196 }
197
198 pub fn set_debug(&self, debug: bool) {
200 self.debug.store(debug, Ordering::SeqCst);
201 }
202
203 pub fn get_event_listener(&self) -> &EventListener {
205 &self.event_listener
206 }
207
208 pub async fn send_rpc_request(
210 &self,
211 command: &str,
212 data: Option<Value>,
213 timeout_ms: Option<u64>,
214 ) -> Result<Option<Value>, String> {
215 let req_id = Uuid::new_v4().to_string();
216 let timeout = timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS);
217
218 let (resp_tx, mut resp_rx) = mpsc::channel(RESP_CH_SIZE);
220 self.rpc_map.lock().unwrap().insert(req_id.clone(), RpcData { resp_tx });
221
222 let mut msg = RpcMessage {
224 command: command.to_string(),
225 req_id: req_id.clone(),
226 data,
227 timeout: Some(timeout),
228 ..Default::default()
229 };
230
231 let auth_token = self.get_auth_token();
233 if !auth_token.is_empty() {
234 msg.auth_token = Some(auth_token);
235 }
236
237 let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
239 self.output_ch
240 .send(json)
241 .await
242 .map_err(|e| format!("send request: {}", e))?;
243
244 let result = tokio::time::timeout(
246 std::time::Duration::from_millis(timeout),
247 resp_rx.recv(),
248 )
249 .await;
250
251 self.rpc_map.lock().unwrap().remove(&req_id);
253
254 match result {
255 Ok(Some(resp)) => {
256 if let Some(err) = resp.error {
257 Err(err)
258 } else {
259 Ok(resp.data)
260 }
261 }
262 Ok(None) => Err("response channel closed".to_string()),
263 Err(_) => Err(format!("RPC timeout after {}ms", timeout)),
264 }
265 }
266
267 pub async fn send_message(&self, command: &str, data: Option<Value>) -> Result<(), String> {
269 let mut msg = RpcMessage {
270 command: command.to_string(),
271 req_id: String::new(), data,
273 ..Default::default()
274 };
275
276 let auth_token = self.get_auth_token();
277 if !auth_token.is_empty() {
278 msg.auth_token = Some(auth_token);
279 }
280
281 let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
282 self.output_ch
283 .send(json)
284 .await
285 .map_err(|e| format!("send message: {}", e))
286 }
287
288 pub fn process_incoming_message(&self, raw_msg: &[u8]) -> Result<(), String> {
290 let msg: RpcMessage =
291 serde_json::from_slice(raw_msg).map_err(|e| format!("json decode: {}", e))?;
292
293 if msg.is_response() {
294 self.handle_response(msg)
295 } else if msg.is_request() {
296 tracing::debug!("incoming request: {} ({})", msg.command, msg.req_id);
298 Ok(())
299 } else {
300 Err("message is neither request nor response".to_string())
301 }
302 }
303
304 fn handle_response(&self, msg: RpcMessage) -> Result<(), String> {
306 let rpc_map = self.rpc_map.lock().unwrap();
307 if let Some(rpc_data) = rpc_map.get(&msg.res_id) {
308 let resp = RpcResponse {
309 data: msg.data,
310 error: msg.error,
311 is_final: !msg.cont,
312 };
313 let _ = rpc_data.resp_tx.try_send(resp);
314 } else if self.debug.load(Ordering::SeqCst) {
315 tracing::warn!(
316 "[{}] received response for unknown req_id: {}",
317 self.debug_name,
318 msg.res_id
319 );
320 }
321 Ok(())
322 }
323
324 pub fn is_server_done(&self) -> bool {
326 self.server_done.load(Ordering::SeqCst)
327 }
328
329 pub fn set_server_done(&self) {
331 self.server_done.store(true, Ordering::SeqCst);
332 }
333
334 pub fn pending_count(&self) -> usize {
336 self.rpc_map.lock().unwrap().len()
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an endpoint named "test", discarding the channel ends a test does not use.
    fn fresh_rpc() -> WshRpc {
        let (rpc, _outgoing, _incoming) = WshRpc::new("test");
        rpc
    }

    #[tokio::test]
    async fn test_wshrpc_create() {
        let rpc = fresh_rpc();
        assert!(!rpc.is_server_done());
        assert_eq!(rpc.pending_count(), 0);
        assert!(rpc.get_auth_token().is_empty());
    }

    #[tokio::test]
    async fn test_wshrpc_auth_token() {
        let rpc = fresh_rpc();
        let secret = "secret123";
        rpc.set_auth_token(secret);
        assert_eq!(rpc.get_auth_token(), secret);
    }

    #[tokio::test]
    async fn test_wshrpc_rpc_context() {
        let rpc = fresh_rpc();
        assert!(rpc.get_rpc_context().is_none());

        let ctx = RpcContext {
            block_id: "block1".into(),
            tab_id: "tab1".into(),
            conn: "local".into(),
        };
        rpc.set_rpc_context(ctx);

        let stored = rpc.get_rpc_context().expect("context was just set");
        assert_eq!(stored.block_id, "block1");
        assert_eq!(stored.tab_id, "tab1");
    }

    #[tokio::test]
    async fn test_wshrpc_send_message() {
        let (rpc, mut outgoing, _incoming) = WshRpc::new("test");
        rpc.set_auth_token("token123");

        let payload = serde_json::json!({"msg": "hello"});
        rpc.send_message("notify", Some(payload))
            .await
            .expect("send_message should succeed");

        let bytes = outgoing.recv().await.expect("an outgoing message");
        let decoded: RpcMessage = serde_json::from_slice(&bytes).expect("valid JSON");
        assert_eq!(decoded.command, "notify");
        assert!(decoded.req_id.is_empty());
        assert_eq!(decoded.auth_token.as_deref(), Some("token123"));
    }

    #[tokio::test]
    async fn test_wshrpc_process_response() {
        let rpc = fresh_rpc();

        let (resp_tx, mut resp_rx) = mpsc::channel(RESP_CH_SIZE);
        rpc.rpc_map
            .lock()
            .unwrap()
            .insert("req-1".to_string(), RpcData { resp_tx });

        let incoming = RpcMessage {
            res_id: "req-1".to_string(),
            data: Some(serde_json::json!({"result": "success"})),
            ..Default::default()
        };
        let raw = serde_json::to_vec(&incoming).expect("encode");
        rpc.process_incoming_message(&raw).expect("route response");

        let routed = resp_rx.recv().await.expect("a routed response");
        assert!(routed.error.is_none());
        assert!(routed.is_final);
        assert_eq!(
            routed.data,
            Some(serde_json::json!({"result": "success"}))
        );
    }
}